{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "S3-tJZy35Ck_"
   },
   "source": [
    "# Trainer\n",
    "Trainer is the highest-level encapsulation in Tianshou. It controls the training loop and the evaluation method. It also controls the interaction between the Collector and the Policy, with the ReplayBuffer serving as the media.\n",
    "\n",
    "<center>\n",
    "<img src=../_static/images/structure.svg></img>\n",
    "</center>\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ifsEQMzZ6mmz"
   },
   "source": [
    "## Usages\n",
    "In Tianshou v0.5.1, there are three types of Trainer. They are designed to be used in  on-policy training, off-policy training and offline training respectively. We will use on-policy trainer as an example and leave the other two for further reading."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XfsuU2AAE52C"
   },
   "source": [
    "### Pseudocode\n",
    "<center>\n",
    "<img src=../_static/images/pseudocode_off_policy.svg></img>\n",
    "</center>\n",
    "\n",
    "For the on-policy trainer, the main difference is that we clear the buffer after Line 10."
   ]
  },
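  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of this difference, here is a minimal sketch in plain Python that does not use Tianshou at all. The function `run_loop`, its parameters and the list standing in for a replay buffer are all hypothetical names chosen for this example; a real trainer also handles testing, logging and stopping criteria.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "\n",
    "def run_loop(n_iterations, steps_per_collect, on_policy, seed=0):\n",
    "    # Toy collect/update loop; returns the buffer size seen at each update.\n",
    "    rng = random.Random(seed)\n",
    "    buffer = []  # hypothetical stand-in for a replay buffer\n",
    "    sizes_at_update = []\n",
    "    for _ in range(n_iterations):\n",
    "        # Collect: append fake transitions (random numbers here)\n",
    "        buffer.extend(rng.random() for _ in range(steps_per_collect))\n",
    "        # Update: a real policy would learn from the buffer at this point\n",
    "        sizes_at_update.append(len(buffer))\n",
    "        if on_policy:\n",
    "            # Data gathered by the old policy is discarded after the update\n",
    "            buffer.clear()\n",
    "    return sizes_at_update\n",
    "\n",
    "\n",
    "on_policy_sizes = run_loop(3, 4, on_policy=True)\n",
    "off_policy_sizes = run_loop(3, 4, on_policy=False)\n",
    "print(on_policy_sizes)  # [4, 4, 4]\n",
    "print(off_policy_sizes)  # [4, 8, 12]\n",
    "```\n",
    "\n",
    "In this sketch the on-policy buffer never holds more than one round of collected data, while the off-policy buffer keeps growing and old transitions remain available for later updates."
   ]
  },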
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Hcp_o0CCFz12"
   },
   "source": [
    "### Training without trainer\n",
    "As we have learned the usages of the Collector and the Policy, it's possible that we write our own training logic.\n",
    "\n",
    "First, let us create the instances of Environment, ReplayBuffer, Policy and Collector."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "editable": true,
    "id": "do-xZ-8B7nVH",
    "slideshow": {
     "slide_type": ""
    },
    "tags": [
     "hide-cell",
     "remove-output"
    ]
   },
   "outputs": [],
   "source": [
    "import gymnasium as gym\n",
    "import torch\n",
    "\n",
    "from tianshou.data import Collector, VectorReplayBuffer\n",
    "from tianshou.env import DummyVectorEnv\n",
    "from tianshou.policy import PGPolicy\n",
    "from tianshou.utils.net.common import Net\n",
    "from tianshou.utils.net.discrete import Actor\n",
    "from tianshou.trainer import OnpolicyTrainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_env_num = 4\n",
    "buffer_size = (\n",
    "    2000  # Since REINFORCE is an on-policy algorithm, we don't need a very large buffer size\n",
    ")\n",
    "\n",
    "# Create the environments, used for training and evaluation\n",
    "env = gym.make(\"CartPole-v1\")\n",
    "test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n",
    "train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n",
    "\n",
    "# Create the Policy instance\n",
    "net = Net(\n",
    "    env.observation_space.shape,\n",
    "    hidden_sizes=[\n",
    "        16,\n",
    "    ],\n",
    ")\n",
    "actor = Actor(net, env.action_space.shape)\n",
    "optim = torch.optim.Adam(actor.parameters(), lr=0.001)\n",
    "policy = PGPolicy(\n",
    "    actor=actor,\n",
    "    optim=optim,\n",
    "    dist_fn=torch.distributions.Categorical,\n",
    "    action_space=env.action_space,\n",
    "    action_scaling=False,\n",
    ")\n",
    "\n",
    "# Create the replay buffer and the collector\n",
    "replaybuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
    "test_collector = Collector(policy, test_envs)\n",
    "train_collector = Collector(policy, train_envs, replaybuffer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wiEGiBgQIiFM"
   },
   "source": [
    "Now, we can try training our policy network. The logic is simple. We collect some data into the buffer and then we use the data to train our policy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "JMUNPN5SI_kd",
    "outputId": "7d68323c-0322-4b82-dafb-7c7f63e7a26d"
   },
   "outputs": [],
   "source": [
    "train_collector.reset()\n",
    "train_envs.reset()\n",
    "test_collector.reset()\n",
    "test_envs.reset()\n",
    "replaybuffer.reset()\n",
    "for i in range(10):\n",
    "    evaluation_result = test_collector.collect(n_episode=10)\n",
    "    print(\"Evaluation reward is {}\".format(evaluation_result[\"rew\"]))\n",
    "    train_collector.collect(n_step=2000)\n",
    "    # 0 means taking all data stored in train_collector.buffer\n",
    "    policy.update(0, train_collector.buffer, batch_size=512, repeat=1)\n",
    "    train_collector.reset_buffer(keep_statistics=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QXBHIBckMs_2"
   },
   "source": [
    "The evaluation reward doesn't seem to improve. That is simply because we haven't trained it for enough time. Plus, the network size is too small and REINFORCE algorithm is actually not very stable. Don't worry, we will solve this problem in the end. Still we get some idea on how to start a training loop."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "p-7U_cwgF5Ej"
   },
   "source": [
    "### Training with trainer\n",
    "The trainer does almost the same thing. The only difference is that it has considered many details and is more modular."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "vcvw9J8RNtFE",
    "outputId": "b483fa8b-2a57-4051-a3d0-6d8162d948c5",
    "tags": [
     "remove-output"
    ]
   },
   "outputs": [],
   "source": [
    "train_collector.reset()\n",
    "train_envs.reset()\n",
    "test_collector.reset()\n",
    "test_envs.reset()\n",
    "replaybuffer.reset()\n",
    "\n",
    "result = OnpolicyTrainer(\n",
    "    policy=policy,\n",
    "    train_collector=train_collector,\n",
    "    test_collector=test_collector,\n",
    "    max_epoch=10,\n",
    "    step_per_epoch=1,\n",
    "    repeat_per_collect=1,\n",
    "    episode_per_test=10,\n",
    "    step_per_collect=2000,\n",
    "    batch_size=512,\n",
    ").run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_j3aUJZQ7nml"
   },
   "source": [
    "## Further Reading\n",
    "### Logger usages\n",
    "Tianshou provides experiment loggers that are both tensorboard- and wandb-compatible. It also has a BaseLogger Class which allows you to self-define your own logger. Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.utils.html#tianshou.utils.BaseLogger) for details.\n",
    "\n",
    "### Learn more about the APIs of Trainers\n",
    "[documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "S3-tJZy35Ck_",
    "XfsuU2AAE52C",
    "p-7U_cwgF5Ej",
    "_j3aUJZQ7nml"
   ],
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
